"""
This file contains specific functions for computing losses on the RetinaNet
file
"""

import math
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import cat

from maskrcnn_benchmark import _C
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist


class Clip(Function):
    """
    Clamp with a straight-through gradient.

    The forward pass clamps ``x`` to the range ``[a, b]`` (either bound may be
    ``None`` to leave that side open). The backward pass hands the incoming
    gradient to ``x`` unchanged, even for elements that were clamped, and
    produces no gradient for the bounds.
    """

    @staticmethod
    def forward(ctx, x, a, b):
        clamped = torch.clamp(x, min=a, max=b)
        return clamped

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        # One slot per forward input (x, a, b); only x receives a gradient.
        grad_x, grad_a, grad_b = grad_output, None, None
        return grad_x, grad_a, grad_b


# Functional alias: ``clip(x, a, b)`` runs ``Clip`` (clamp in forward, gradient
# passed through unchanged in backward).
clip = Clip.apply


def negative_bag_loss(logits):
    """
    Summed, squared-probability-weighted BCE against an all-zero target.

    Arguments:
        logits (Tensor): despite the name, these are passed directly to
            ``F.binary_cross_entropy`` and must therefore already be
            probabilities in [0, 1].

    Returns:
        Tensor: 0-dim tensor, the sum over all elements of
            ``p ** 2 * BCE(p, 0)``.
    """
    background = torch.zeros_like(logits)
    per_element_bce = F.binary_cross_entropy(logits, background, reduction='none')
    # Weighting by p ** 2 down-weights elements that are already close to 0.
    weighted = logits ** 2 * per_element_bce
    return weighted.sum()


def positive_bag_loss(logits):
    """
    BCE (against target 1) of a single weighted-average "bag" probability.

    Each element gets a weight proportional to ``1 / (1 - p)``, so elements
    with a higher probability dominate the average. The weights are
    normalized to sum to one over all elements of ``logits``.

    Arguments:
        logits (Tensor): despite the name, these are used as probabilities;
            the weighted average is fed to ``F.binary_cross_entropy``.

    Returns:
        Tensor: 0-dim tensor, ``BCE(sum(w * p), 1)``.
    """
    # The lower bound keeps the reciprocal finite when a probability equals 1.
    inv_margin = 1 / clip(1 - logits, 1e-12, None)
    inv_margin /= inv_margin.sum()
    bag_prob = torch.sum(inv_margin * logits)
    all_positive = torch.ones_like(bag_prob)
    return F.binary_cross_entropy(bag_prob, all_positive)


class MultiAnchorLossComputation(object):
    """
    This class computes the multi-anchor RetinaNet loss.

    For every ground-truth object a "bag" of the ``pre_anchor_topk`` anchors
    with the highest IoU is selected; the positive loss is computed on the
    product of classification probability and ``exp(-regression_loss)`` over
    that bag. The negative loss is computed over all anchors, with the
    classification probability attenuated where the regressed box already
    overlaps an object of that class.
    """

    def __init__(self, cfg, box_coder):
        """
        Arguments:
            cfg: configuration node; ``RETINANET.NUM_CLASSES`` and the
                ``MULTIANCHOR.*`` options listed below are read from it.
            box_coder (BoxCoder): used to encode ground-truth boxes relative
                to the anchors.
        """
        self.box_coder = box_coder
        # The classification head has NUM_CLASSES - 1 channels; labels are
        # mapped to channels with ``label - 1`` in __call__.
        self.num_classes = cfg.RETINANET.NUM_CLASSES - 1
        self.iou_threshold = cfg.MULTIANCHOR.IOU_THRESHOLD
        self.pre_anchor_topk = cfg.MULTIANCHOR.PRE_ANCHOR_TOPK
        # (weight, beta) forwarded positionally to smooth_l1_loss.
        self.smooth_l1_loss_param = (cfg.MULTIANCHOR.BBOX_REG_WEIGHT, cfg.MULTIANCHOR.BBOX_REG_BETA)
        # (low, high) IoU thresholds used to rescale IoU into [0, 1].
        self.bbox_threshold = (cfg.MULTIANCHOR.BBOX_THRESHOLD_L, cfg.MULTIANCHOR.BBOX_THRESHOLD_H)
        self.focal_loss_alpha = cfg.MULTIANCHOR.FOCAL_LOSS_ALPHA

        self.positive_bag_loss_func = positive_bag_loss
        self.negative_bag_loss_func = negative_bag_loss

    def __call__(self, anchors, box_cls, box_regression, targets):
        """
        Arguments:
            anchors (list[list[BoxList]]): per image, the anchors of every
                feature level; they are concatenated per image here.
            box_cls (list[Tensor]): per feature level, classification output
                of shape (N, A * C, H, W).
            box_regression (list[Tensor]): per feature level, regression
                output of shape (N, A * 4, H, W).
            targets (list[BoxList]): per image, ground-truth boxes carrying a
                "labels" field.

        Returns:
            dict[str, Tensor]: the scalar losses under the keys
                "loss_retina_positive" and "loss_retina_negative".
        """
        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
        box_cls_flattened = []
        box_regression_flattened = []
        # for each feature level, permute the outputs to make them be in the
        # same format as the labels. Note that the labels are computed for
        # all feature levels concatenated, so we keep the same representation
        # for the objectness and the box_regression
        for box_cls_per_level, box_regression_per_level in zip(
            box_cls, box_regression
        ):
            # The channel dimension is A * C and is split by the view below.
            N, _, H, W = box_cls_per_level.shape
            C = self.num_classes
            box_cls_per_level = box_cls_per_level.view(N, -1, C, H, W)
            box_cls_per_level = box_cls_per_level.permute(0, 3, 4, 1, 2)
            box_cls_per_level = box_cls_per_level.reshape(N, -1, C)
            box_regression_per_level = box_regression_per_level.view(N, -1, 4, H, W)
            box_regression_per_level = box_regression_per_level.permute(0, 3, 4, 1, 2)
            box_regression_per_level = box_regression_per_level.reshape(N, -1, 4)
            box_cls_flattened.append(box_cls_per_level)
            box_regression_flattened.append(box_regression_per_level)
        # concatenate on the first dimension (representing the feature levels), to
        # take into account the way the labels were generated (with all feature maps
        # being concatenated as well)
        box_cls = cat(box_cls_flattened, dim=1)
        box_regression = cat(box_regression_flattened, dim=1)

        device = box_cls.device
        cls_prob = torch.sigmoid(box_cls)
        box_prob = torch.zeros_like(box_cls)
        # box_prob_list[img][cls] is a per-anchor vector so that single
        # (image, class) entries can be replaced without in-place tensor writes.
        box_prob_list = [
            list(torch.unbind(box_prob_pre_img, dim=1)) for box_prob_pre_img in torch.unbind(box_prob, dim=0)
        ]
        positive_numels = []
        # Seeded with a zero so torch.stack works when no object contributes.
        positive_losses = [torch.tensor(0., dtype=torch.float, device=device)]
        for img, (anchors_, targets_, cls_prob_, box_regression_) in enumerate(
                zip(anchors, targets, cls_prob, box_regression)
        ):
            match_quality_matrix = boxlist_iou(targets_, anchors_)
            for ind, label in enumerate(targets_.get_field("labels")):
                object_anchor_iou = match_quality_matrix[ind, :]
                # torch.topk raises when k exceeds the number of candidates,
                # so never ask for more anchors than exist.
                bag_size = min(self.pre_anchor_topk, object_anchor_iou.numel())
                _, matched = torch.topk(object_anchor_iou, bag_size, dim=0, sorted=False)
                matched_cls_prob = cls_prob_[matched, label - 1]
                matched_numel = matched_cls_prob.size(0)
                if matched_numel == 0:
                    continue
                object_targets = self.box_coder.encode(targets_.bbox[ind, :].unsqueeze(0), anchors_.bbox)
                retinanet_regression_loss = smooth_l1_loss(
                    box_regression_, object_targets, *self.smooth_l1_loss_param
                )
                object_box_iou = regression_target_iou(box_regression_, object_targets)
                # Rescale IoU linearly so that the low threshold maps to 0 and
                # the high threshold to 1; detached, so it only acts as a weight.
                object_box_prob = clip(
                    (object_box_iou - self.bbox_threshold[0]) / (self.bbox_threshold[1] - self.bbox_threshold[0]), 0, 1
                ).detach()
                # Keep, per anchor and class, the maximum over all objects.
                box_prob_list[img][label - 1] = torch.max(box_prob_list[img][label - 1], object_box_prob)

                matched_box_prob = torch.exp(-retinanet_regression_loss)[matched]
                positive_numels.append(
                    matched_numel
                )
                positive_losses.append(
                    self.positive_bag_loss_func(matched_cls_prob * matched_box_prob)
                )

        # Average over the number of objects that contributed a bag.
        positive_loss = torch.stack(positive_losses).sum() / max(1, len(positive_numels))

        box_prob = torch.stack([
            torch.stack(box_prob_pre_img, dim=1) for box_prob_pre_img in box_prob_list
        ], dim=0)

        anchor_prob = cls_prob * (1 - box_prob)

        # Normalized by the total number of anchors placed in positive bags.
        negative_loss = (
            self.negative_bag_loss_func(anchor_prob) / max(1, sum(positive_numels))
        )

        losses = {
            "loss_retina_positive": positive_loss * self.focal_loss_alpha,
            "loss_retina_negative": negative_loss * (1 - self.focal_loss_alpha),
        }
        return losses


def smooth_l1_loss(pred, target, weight, beta):
    """
    Smooth L1 loss summed over dim 1 and scaled by ``weight``.

    Per element, with ``d = target - pred``:
        ``0.5 * d ** 2 / beta``  if ``|d| < beta``
        ``|d| - 0.5 * beta``     otherwise

    Arguments:
        pred (Tensor): predictions, at least 2-D.
        target (Tensor): targets, broadcastable against ``pred``.
        weight (float): multiplicative factor applied to the result.
        beta (float): width of the quadratic zone. ``beta == 0`` yields the
            plain L1 loss.

    Returns:
        Tensor: the loss with dim 1 reduced by summation.
    """
    val = target - pred
    abs_val = val.abs()
    if beta == 0:
        # The quadratic zone is empty; evaluating 0.5 / beta would raise.
        per_element = abs_val
    else:
        per_element = torch.where(abs_val < beta, 0.5 / beta * val ** 2, abs_val - 0.5 * beta)
    return weight * per_element.sum(dim=1)


def l2_loss(pred, target, weight, beta):
    """
    Squared-error loss summed over dim 1 and scaled by ``weight``.

    Per element the loss is ``0.5 * (target - pred) ** 2 / beta``.

    Arguments:
        pred (Tensor): predictions, at least 2-D.
        target (Tensor): targets, broadcastable against ``pred``.
        weight (float): multiplicative factor applied to the result.
        beta (float): divisor of the quadratic term.

    Returns:
        Tensor: the loss with dim 1 reduced by summation.
    """
    residual = target - pred
    scale = 0.5 / beta
    per_row = (scale * residual ** 2).sum(dim=1)
    return weight * per_row


def regression_target_iou(pred, target):
    """
    IoU between boxes given as (center x, center y, log width, log height).

    The last dimension of both inputs is unpacked into those four values;
    widths and heights are recovered with ``exp``.

    Arguments:
        pred (Tensor): predicted boxes, last dimension of size 4.
        target (Tensor): target boxes, same layout as ``pred``.

    Returns:
        Tensor: element-wise IoU with the last dimension removed. Boxes that
            do not overlap yield 0.
    """
    x1, y1, w1, h1 = torch.unbind(pred, dim=-1)
    x2, y2, w2, h2 = torch.unbind(target, dim=-1)
    area1, area2 = (w1 + h1).exp(), (w2 + h2).exp()
    # Half extents, used to go from centers to corner coordinates.
    w1, h1, w2, h2 = w1.exp() / 2, h1.exp() / 2, w2.exp() / 2, h2.exp() / 2
    x11, y11, x12, y12 = x1 - w1, y1 - h1, x1 + w1, y1 + h1
    x21, y21, x22, y22 = x2 - w2, y2 - h2, x2 + w2, y2 + h2
    # Clamp each extent at zero: without it, boxes disjoint on both axes give
    # a positive product of two negative extents, and boxes disjoint on one
    # axis give a negative "intersection".
    inter_w = (torch.min(x12, x22) - torch.max(x11, x21)).clamp(min=0)
    inter_h = (torch.min(y12, y22) - torch.max(y11, y21)).clamp(min=0)
    inter = inter_w * inter_h
    return inter / (area1 + area2 - inter)


def make_multi_anchor_loss_evaluator(cfg, box_coder):
    """
    Build the multi-anchor loss evaluator.

    Arguments:
        cfg: configuration node forwarded to ``MultiAnchorLossComputation``.
        box_coder: box coder forwarded to ``MultiAnchorLossComputation``.

    Returns:
        MultiAnchorLossComputation: the constructed loss evaluator.
    """
    loss_evaluator = MultiAnchorLossComputation(cfg, box_coder)
    return loss_evaluator
